{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tqrD7Yzlmlsk"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2k8X1C1nmpKv"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32xflLc4NTx-"
      },
      "source": [
        "# Custom Federated Algorithms, Part 2: Implementing Federated Averaging"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jtATV6DlqPs0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_2\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/custom_federated_algorithms_2.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/tutorials/custom_federated_algorithms_2.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/custom_federated_algorithms_2.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_igJ2sfaNWS8"
      },
      "source": [
        "This tutorial is the second part of a two-part series that demonstrates how to\n",
        "implement custom types of federated algorithms in TFF using the\n",
        "[Federated Core (FC)](../federated_core.md), which serves as a foundation for\n",
        "the [Federated Learning (FL)](../federated_learning.md) layer (`tff.learning`).\n",
        "\n",
        "We encourage you to first read the\n",
        "[first part of this series](custom_federated_algorithms_1.ipynb), which\n",
        "introduces some of the key concepts and programming abstractions used here.\n",
        "\n",
        "This second part of the series uses the mechanisms introduced in the first part\n",
        "to implement a simple version of federated training and evaluation algorithms.\n",
        "\n",
        "We encourage you to review the\n",
        "[image classification](federated_learning_for_image_classification.ipynb) and\n",
        "[text generation](federated_learning_for_text_generation.ipynb) tutorials for a\n",
        "higher-level and more gentle introduction to TFF's Federated Learning APIs, as\n",
        "they will help you put the concepts we describe here in context."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cuJuLEh2TfZG"
      },
      "source": [
        "## Before we start\n",
        "\n",
        "Before we start, try to run the following \"Hello World\" example to make sure\n",
        "your environment is set up correctly. If it doesn't work, please refer to the\n",
        "[Installation](../install.md) guide for instructions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rB1ovcX1mBxQ"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "!pip install --quiet --upgrade nest-asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-skNC6aovM46"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "\n",
        "# TODO(b/148678573,b/148685415): must use the reference context because it\n",
        "# supports unbounded references and tff.sequence_* intrinsics.\n",
        "tff.backends.reference.set_reference_context()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zzXwGnZamIMM"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'Hello, World!'"
            ]
          },
          "execution_count": 4,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "@tff.federated_computation\n",
        "def hello_world():\n",
        "  return 'Hello, World!'\n",
        "\n",
        "hello_world()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iu5Gd8D6W33s"
      },
      "source": [
        "## Implementing Federated Averaging\n",
        "\n",
        "As in\n",
        "[Federated Learning for Image Classification](federated_learning_for_image_classification.ipynb),\n",
        "we are going to use the MNIST example, but since this is intended as a low-level\n",
        "tutorial, we are going to bypass the Keras API and `tff.simulation`, write raw\n",
        "model code, and construct a federated data set from scratch.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b6qCjef350c_"
      },
      "source": [
        "### Preparing federated data sets\n",
        "\n",
        "For the sake of a demonstration, we're going to simulate a scenario in which we\n",
        "have data from 10 users, and each of the users contributes knowledge of how to\n",
        "recognize a different digit. This is about as\n",
        "non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables)\n",
        "as it gets.\n",
        "\n",
        "First, let's load the standard MNIST data:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uThZM4Ds-KDQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "11493376/11490434 [==============================] - 0s 0us/step\n",
            "11501568/11490434 [==============================] - 0s 0us/step\n"
          ]
        }
      ],
      "source": [
        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PkJc5rHA2no_"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]"
            ]
          },
          "execution_count": 6,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "[(x.dtype, x.shape) for x in mnist_train]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mFET4BKJFbkP"
      },
      "source": [
        "The data comes as NumPy arrays, one with images and another with digit labels, both\n",
        "with the first dimension going over the individual examples. Let's write a\n",
        "helper function that formats it in a way compatible with how we feed federated\n",
        "sequences into TFF computations, i.e., as a list of lists - the outer list\n",
        "ranging over the users (digits), the inner ones ranging over batches of data in\n",
        "each client's sequence. As is customary, we will structure each batch as a pair\n",
        "of tensors named `x` and `y`, each with the leading batch dimension. While at\n",
        "it, we'll also flatten each image into a 784-element vector and rescale the\n",
        "pixels in it into the `0..1` range, so that we don't have to clutter the model\n",
        "logic with data conversions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XTaTLiq5GNqy"
      },
      "outputs": [],
      "source": [
        "NUM_EXAMPLES_PER_USER = 1000\n",
        "BATCH_SIZE = 100\n",
        "\n",
        "\n",
        "def get_data_for_digit(source, digit):\n",
        "  output_sequence = []\n",
        "  all_samples = [i for i, d in enumerate(source[1]) if d == digit]\n",
        "  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):\n",
        "    batch_samples = all_samples[i:i + BATCH_SIZE]\n",
        "    output_sequence.append({\n",
        "        'x':\n",
        "            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],\n",
        "                     dtype=np.float32),\n",
        "        'y':\n",
        "            np.array([source[1][i] for i in batch_samples], dtype=np.int32)\n",
        "    })\n",
        "  return output_sequence\n",
        "\n",
        "\n",
        "federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]\n",
        "\n",
        "federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]"
      ]
    },
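    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the nesting concrete, here is a minimal standalone sketch in plain Python\n",
        "that applies the same grouping logic to a handful of toy labels. The names\n",
        "`toy_batches_for_label`, `TOY_BATCH_SIZE`, and `TOY_EXAMPLES_PER_USER` are\n",
        "hypothetical and exist only for this illustration. Each batch holds example\n",
        "indices rather than image tensors, to keep the output readable: the outer list\n",
        "ranges over users (one per label), the inner lists over that user's batches."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "TOY_BATCH_SIZE = 2\n",
        "TOY_EXAMPLES_PER_USER = 4\n",
        "\n",
        "\n",
        "def toy_batches_for_label(labels, wanted):\n",
        "  # Indices of all examples that carry the wanted label.\n",
        "  indices = [i for i, label in enumerate(labels) if label == wanted]\n",
        "  limit = min(len(indices), TOY_EXAMPLES_PER_USER)\n",
        "  # Step through the capped range, slicing out one batch at a time.\n",
        "  return [\n",
        "      indices[i:i + TOY_BATCH_SIZE]\n",
        "      for i in range(0, limit, TOY_BATCH_SIZE)\n",
        "  ]\n",
        "\n",
        "\n",
        "toy_labels = [0, 1, 0, 1, 1, 0, 0, 0, 1]\n",
        "\n",
        "# One entry per user (label); each entry is that user's list of batches.\n",
        "toy_federated = [toy_batches_for_label(toy_labels, d) for d in range(2)]\n",
        "print(toy_federated)"
      ]
    },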
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xpNdBimWaMHD"
      },
      "source": [
        "As a quick sanity check, let's look at the `y` tensor in the last batch of data\n",
        "contributed by the client at index 5 (the one corresponding to the digit `5`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bTNuL1W4bcuc"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,\n",
              "       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)"
            ]
          },
          "execution_count": 8,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "federated_train_data[5][-1]['y']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xgvcwv7Obhat"
      },
      "source": [
        "Just to be sure, let's also look at the image corresponding to the last element of that batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cI4aat1za525"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEBZJREFUeJzt3W1sU2Ufx/HfHgJKAOOQzo6hY9mE\nyTbINrMYZQjIgyaOh4Vn4mCEJeALA6hB34AJgSVCggmoVIgsMWogymZAFoiAyoxZivQFSAxOFsco\nG8gMAxU2PPcLc++WW3q1dH2C6/tJTrL239Pzz9l+O+25TnslOY7jCIB1kuPdAID4IPyApQg/YCnC\nD1iK8AOWIvyApQg/YCnCD1iK8AOWSo3lxpKSkmK5OcBKoV6026cjf0NDg0aOHKmcnBzV1NT05akA\nxJoTpp6eHic7O9tpbm52rl+/7hQWFjqnTp0yriOJhYUlykuowj7yNzU1KScnR9nZ2erXr5/mzZun\n+vr6cJ8OQIyFHf62tjYNHz6893ZmZqba2tr+9TiPx6OSkhKVlJSEuykAURD2Cb/bnVS43Qm96upq\nVVdXB6wDiI+wj/yZmZlqbW3tvX3u3DllZGREpCkAMRDuCb/u7m5nxIgRzs8//9x7wu/kyZOc8GNh\nifMSqrBf9qempmrr1q2aOnWqbt68qaqqKo0ePTrcpwMQY0nO7d68R2tjvOcHoi7USHN5L2Apwg9Y\nivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Apwg9YivADliL8gKUIP2Ap\nwg9YivADliL8gKUIP2CpmE7RjcTz0ksvGevXrl0z1nft2hXBbm6VlZVlrCcnm49dc+bMCVgbNmyY\ncd0VK1YY688++6yxfuTIEWM9EXDkByxF+AFLEX7AUoQfsBThByxF+AFLEX7AUn0a58/KytKgQYOU\nkpKi1NRUeb3eSPWFGJkxY4axPmHCBGN9yJAhxrrP5wtYW7BggXHdRYsWGespKSnGel9cvXrVWO/s\n7IzatmOlzxf5HDlyRA899FAkegEQQ7zsByzVp/AnJSVpypQpKi4ulsfjiVRPAGKgTy/7GxsblZGR\noY6ODk2ePFmjRo1SWVnZLY/xeDz8YwASUJ+O/BkZGZIkl8ulmTNnqqmp6V+Pqa6ultfr5WQgkGDC\nDv+1a9fU1dXV+/PBgweVn58fscYARFfYL/vb29s1c+ZMSVJPT48WLFigadOmRawxANGV5DiOE7ON\nJSXFalMI0aFDh4z1YOP8wX6nMfzzuiMrV6401g8cOGCs//TTT5FsJ6JC3ecM9QGWIvyApQg/YCnC\nD1iK8AOWIvyApfjq7nuAabjtqaeeMq47fvz4SLcTsj/++MNY/+9FZIE0NDQY6+vXrw9YO3v2rHHd\nRB2ijCSO/IClCD9gKcIPWIrwA5Yi/IClCD9gKcIPWIqP9N4DBg4cGLD222+/RXXbN27cMNY///zz\ngLVNmzYZ1+Xbn8LDR3oBGBF+wFKEH7AU4QcsRfgBSxF+wFKEH7AUn+e/B8yePTtu216xYoWxvmvX\nrtg0gjvGkR+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsFHeevqqrSvn375HK5dPLkSUnS5cuXNXfu\nXLW0tCgrK0u7d+/Wgw8+GPVmbTVnzhxjfcuWLVHb9jvvvGOsM45/9wp65F+8ePG/JkeoqanRpEmT\ndObMGU2aNEk1NTVRaxBAdAQNf1lZmdLS0m65r76+XpWVlZKkyspK1dXVRac7AFET1nv+9vZ2ud1u\nSZLb7VZHR0dEmwIQfVG/tt/j8cjj8UR7MwDuUFhH/vT0dPn9fkmS3++Xy+UK+Njq6mp5vV6+jBFI\nMGGFv7y8XLW1tZKk2tpaTZ8+
PaJNAYi+oOGfP3++nnzySf3444/KzMzUzp07tWbNGh06dEi5ubk6\ndOiQ1qxZE4teAUQQ39ufAAYMGGCsf/vtt8Z6fn5+2Ns+fPiwsV5RUWGsd3V1hb1tRAff2w/AiPAD\nliL8gKUIP2Apwg9YivADluKru2Ogf//+xvr27duN9b4M5QWzceNGY52hvHsXR37AUoQfsBThByxF\n+AFLEX7AUoQfsBThByzFOH8MPPPMM8b6/PnzY9PIbcyaNctYLywsNNavXLlirH/wwQd33BNigyM/\nYCnCD1iK8AOWIvyApQg/YCnCD1iK8AOW4qu7Y2D//v3G+rRp02LUSeQlJ5uPH/X19QFrwfbLzp07\njfW//vrLWLcVX90NwIjwA5Yi/IClCD9gKcIPWIrwA5Yi/IClgo7zV1VVad++fXK5XDp58qQkad26\ndXr//fc1dOhQSdKGDRv0/PPPB9+YpeP8RUVFxvq7775rrBcXF4e97dOnTxvrfr/fWB8+fLix/thj\njxnrfbmMZM2aNcb6pk2bwn7ue1nExvkXL16shoaGf92/cuVK+Xw++Xy+kIIPILEEDX9ZWZnS0tJi\n0QuAGAr7Pf/WrVtVWFioqqoqdXZ2RrInADEQVviXL1+u5uZm+Xw+ud1urV69OuBjPR6PSkpKVFJS\nEnaTACIvrPCnp6crJSVFycnJWrZsmZqamgI+trq6Wl6vV16vN+wmAUReWOH/5xnivXv3RnUWWQDR\nEfSru+fPn6+jR4/q0qVLyszM1JtvvqmjR4/K5/MpKSlJWVlZQaeYBpB4+Dx/AhgwYICxnp2dHfZz\nt7W1GevBTtYOGTLEWB85cqSx/vrrrwesPffcc8Z1b968aazPmDHDWD9w4ICxfq/i8/wAjAg/YCnC\nD1iK8AOWIvyApQg/YCmG+iLg/vvvN9b//PNPYz2Gv4KYS0lJCVjz+XzGdfPy8oz1xsZGY338+PHG\n+r2KoT4ARoQfsBThByxF+AFLEX7AUoQfsBThBywV9PP8+NsDDzwQsPbRRx8Z1509e7ax/vvvv4fV\n091g4MCBAWv33Xdfn547NZU/377gyA9YivADliL8gKUIP2Apwg9YivADliL8gKUYKA2RabqxqVOn\nGtcNNo11sM+1JzLTOL4kffjhhwFrI0aMiHQ7uAMc+QFLEX7AUoQfsBThByxF+AFLEX7AUoQfsFTQ\ncf7W1la9+OKLunDhgpKTk1VdXa2XX35Zly9f1ty5c9XS0qKsrCzt3r1bDz74YCx6vus0NDQY66Zp\nrCVpz549kWznjixevNhYX7t2rbHel7+J7u5uY/29994L+7kRwpE/NTVVmzdv1unTp/Xdd99p27Zt\n+uGHH1RTU6NJkybpzJkzmjRpkmpqamLRL4AICRp+t9utoqIiSdKgQYOUl5entrY21dfXq7KyUpJU\nWVmpurq66HYKIKLu6D1/S0uLTpw4odLSUrW3t8vtdkv6+x9ER0dHVBoEEB0hX9t/9epVVVRUaMuW\nLRo8eHDIG/B4PPJ4PGE1ByB6Qjryd3d3q6KiQgsXLtSsWbMkSenp6fL7/ZIkv98vl8t123Wrq6vl\n9Xrl9Xoj1DKASAgafsdxtHTpUuXl5WnVqlW995eXl6u2tlaSVFtbq+nTp0evSwARF3SK7mPHjmnc\nuHEqKChQcvLf/ys2bNig0tJSzZkzR7/88oseeeQR7dmzR2lpaeaN3cVTdJeWlgasHT582Lhu//79\nI91Owgj2OzX9eXV2dhrXDTYEumPHDmPdVqFO0R30Pf/TTz8d8Mm+/PLLO+sKQMLgCj/AUoQfsBTh\nByxF+AFLEX7AUoQfsFTQcf6IbuwuHuc3WbJkibEe7KOnKSkpkWwnpoL9Ti9evBiwVlFRYVy3sbEx\nrJ5sF2qkOfIDliL8gKUIP2Apwg9YivADliL8gKUIP2ApxvljYNSoUcb6Z599ZqwHm+I7moJNH7
5v\n3z5j3XSNw4ULF8LqCWaM8wMwIvyApQg/YCnCD1iK8AOWIvyApQg/YCnG+YF7DOP8AIwIP2Apwg9Y\nivADliL8gKUIP2Apwg9YKmj4W1tbNWHCBOXl5Wn06NF6++23JUnr1q3TsGHDNHbsWI0dO1ZffPFF\n1JsFEDlBL/Lx+/3y+/0qKipSV1eXiouLVVdXp927d2vgwIF65ZVXQt8YF/kAURfqRT6pwR7gdrvl\ndrslSYMGDVJeXp7a2tr61h2AuLuj9/wtLS06ceKESktLJUlbt25VYWGhqqqq1NnZedt1PB6PSkpK\nVFJS0vduAUSOE6Kuri6nqKjI+fTTTx3HcZwLFy44PT09zs2bN5033njDWbJkSdDnkMTCwhLlJVQh\nPfLGjRvOlClTnM2bN9+2fvbsWWf06NGEn4UlAZZQBX3Z7ziOli5dqry8PK1atar3fr/f3/vz3r17\nlZ+fH+ypACSQoGf7jx07pnHjxqmgoEDJyX//r9iwYYM+/vhj+Xw+JSUlKSsrS9u3b+89MRhwY5zt\nB6IuSKR78Xl+4B4TaqS5wg+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsRfsBShB+wFOEHLEX4AUsR\nfsBShB+wFOEHLBX0CzwjaciQIcrKyuq9ffHiRQ0dOjSWLYQsUXtL1L4kegtXJHtraWkJ+bEx/Tz/\n/yspKZHX643X5o0StbdE7Uuit3DFqzde9gOWIvyApVLWrVu3Lp4NFBcXx3PzRonaW6L2JdFbuOLR\nW1zf8wOIH172A5aKS/gbGho0cuRI5eTkqKamJh4tBJSVlaWCggKNHTs27lOMVVVVyeVy3TInwuXL\nlzV58mTl5uZq8uTJAadJi0dviTJzc6CZpeO97xJuxuuQp/eIkJ6eHic7O9tpbm52rl+/7hQWFjqn\nTp2KdRsBPfroo87Fixfj3YbjOI7z1VdfOcePH79lNqRXX33V2bhxo+M4jrNx40bntddeS5je1q5d\n67z11ltx6eefzp8/7xw/ftxxHMe5cuWKk5ub65w6dSru+y5QX/HabzE/8jc1NSknJ0fZ2dnq16+f\n5s2bp/r6+li3cVcoKytTWlraLffV19ersrJSklRZWam6urp4tHbb3hKF2+1WUVGRpFtnlo73vgvU\nV7zEPPxtbW0aPnx47+3MzMyEmvI7KSlJU6ZMUXFxsTweT7zb+Zf29vbemZHcbrc6Ojri3NGtQpm5\nOZb+ObN0Iu27cGa8jrSYh9+5zeBCIs3k09jYqO+//14HDhzQtm3b9PXXX8e7pbvG8uXL1dzcLJ/P\nJ7fbrdWrV8e1n6tXr6qiokJbtmzR4MGD49rLP/1/X/HabzEPf2ZmplpbW3tvnzt3ThkZGbFuI6D/\n9uJyuTRz5kw1NTXFuaNbpaen906S6vf75XK54tzR/6SnpyslJUXJyclatmxZXPddd3e3KioqtHDh\nQs2aNau3v3jvu0B9xWO/xTz8TzzxhM6cOaOzZ8/qxo0b+uSTT1ReXh7rNm7r2rVr6urq6v354MGD\nCTf7cHl5uWprayVJtbW1mj59epw7+p9EmbnZCTCzdLz3XaC+4rbfYn6K0XGc/fv3O7m5uU52draz\nfv36eLRwW83NzU5hYaFTWFjoPP7443Hvbd68ec7DDz/spKamOsOGDXN27NjhXLp0yZk4caKTk5Pj\nTJw40fn1118TprdFixY5+fn5TkFBgfPCCy8458+fj0tv33zzjSPJKSgocMaMGeOMGTPG2b9/f9z3\nXaC+4rXfuMIPsBRX+AGWIvyApQg/YCnCD1iK8AOWIvyApQg/YCnCD1jqP1hNIrYb+rn+AAAAAElF\nTkSuQmCC\n",
            "text/plain": [
              "<Figure size 600x400 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "from matplotlib import pyplot as plt\n",
        "\n",
        "plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J-ox58PA56f8"
      },
      "source": [
        "### On combining TensorFlow and TFF\n",
        "\n",
        "In this tutorial, for compactness we immediately decorate functions that\n",
        "introduce TensorFlow logic with `tff.tf_computation`. However, for more complex\n",
        "logic, this is not the pattern we recommend. Debugging TensorFlow can already be\n",
        "a challenge, and debugging TensorFlow after it has been fully serialized and\n",
        "then re-imported necessarily loses some metadata and limits interactivity,\n",
        "making debugging even more of a challenge.\n",
        "\n",
        "Therefore, **we strongly recommend writing complex TF logic as stand-alone\n",
        "Python functions** (that is, without `tff.tf_computation` decoration). This way\n",
        "the TensorFlow logic can be developed and tested using TF best practices and\n",
        "tools (like eager mode), before serializing the computation for TFF (e.g., by invoking `tff.tf_computation` with a Python function as the argument)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RSd6UatXbzw-"
      },
      "source": [
        "### Defining a loss function\n",
        "\n",
        "Now that we have the data, let's define a loss function that we can use for\n",
        "training. First, let's define the type of input as a TFF named tuple. Since the\n",
        "size of data batches may vary, we set the batch dimension to `None` to indicate\n",
        "that the size of this dimension is unknown."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "653xv5NXd4fy"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'<x=float32[?,784],y=int32[?]>'"
            ]
          },
          "execution_count": 10,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "BATCH_SPEC = collections.OrderedDict(\n",
        "    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n",
        "    y=tf.TensorSpec(shape=[None], dtype=tf.int32))\n",
        "BATCH_TYPE = tff.to_type(BATCH_SPEC)\n",
        "\n",
        "str(BATCH_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pb6qPUvyh5A1"
      },
      "source": [
        "You may be wondering why we can't just define an ordinary Python type. Recall\n",
        "the discussion in [part 1](custom_federated_algorithms_1.ipynb), where we\n",
        "explained that while we can express the logic of TFF computations using Python,\n",
        "under the hood TFF computations *are not* Python. The symbol `BATCH_TYPE`\n",
        "defined above represents an abstract TFF type specification. It is important to\n",
        "distinguish this *abstract* TFF type from concrete Python *representation*\n",
        "types, e.g., containers such as `dict` or `collections.namedtuple` that may be\n",
        "used to represent the TFF type in the body of a Python function. Unlike Python,\n",
        "TFF has a single abstract type constructor `tff.StructType` for tuple-like\n",
        "containers, with elements that can be individually named or left unnamed. This\n",
        "type is also used to model formal parameters of computations, as TFF\n",
        "computations can formally only declare one parameter and one result - you will\n",
        "see examples of this shortly.\n",
        "\n",
        "Let's now define the TFF type of model parameters, again as a TFF named tuple of\n",
        "*weights* and *bias*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Og7VViafh-30"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<weights=float32[784,10],bias=float32[10]>\n"
          ]
        }
      ],
      "source": [
        "MODEL_SPEC = collections.OrderedDict(\n",
        "    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),\n",
        "    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))\n",
        "MODEL_TYPE = tff.to_type(MODEL_SPEC)\n",
        "\n",
        "print(MODEL_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iHhdaWSpfQxo"
      },
      "source": [
        "With those definitions in place, we can now define the loss for a given model over a single batch. Note the use of a `@tf.function`-decorated function inside the `@tff.tf_computation`-decorated one. This allows us to write TF using Python-like semantics even though we're inside a `tf.Graph` context created by the `tff.tf_computation` decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EObiz_Ke0uK"
      },
      "outputs": [],
      "source": [
        "# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can\n",
        "# later be called from within another tf.function. This is necessary because a\n",
        "# @tf.function-decorated function cannot invoke a @tff.tf_computation.\n",
        "\n",
        "@tf.function\n",
        "def forward_pass(model, batch):\n",
        "  predicted_y = tf.nn.softmax(\n",
        "      tf.matmul(batch['x'], model['weights']) + model['bias'])\n",
        "  return -tf.reduce_mean(\n",
        "      tf.reduce_sum(\n",
        "          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))\n",
        "\n",
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "def batch_loss(model, batch):\n",
        "  return forward_pass(model, batch)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8K0UZHGnr8SB"
      },
      "source": [
        "As expected, computation `batch_loss` returns `float32` loss given the model and\n",
        "a single data batch. Note how the `MODEL_TYPE` and `BATCH_TYPE` have been lumped\n",
        "together into a 2-tuple of formal parameters; you can recognize the type of\n",
        "`batch_loss` as `(<MODEL_TYPE,BATCH_TYPE> -> float32)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4WXEAY8Nr89V"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'"
            ]
          },
          "execution_count": 13,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_loss.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pAnt_UcdnvGa"
      },
      "source": [
        "As a sanity check, let's construct an initial model filled with zeros and\n",
        "compute the loss over the batch of data we visualized above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U8Ne8igan3os"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "2.3025854"
            ]
          },
          "execution_count": 14,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "initial_model = collections.OrderedDict(\n",
        "    weights=np.zeros([784, 10], dtype=np.float32),\n",
        "    bias=np.zeros([10], dtype=np.float32))\n",
        "\n",
        "sample_batch = federated_train_data[5][-1]\n",
        "\n",
        "batch_loss(initial_model, sample_batch)"
      ]
    },
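    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The value `2.3025854` is what we should expect. With all-zero weights and bias,\n",
        "every logit is zero, so the softmax assigns probability 1/10 to each of the 10\n",
        "classes, and the cross-entropy of any label is `-log(1/10) = log(10)`. The cell\n",
        "below is a minimal standalone sketch of that arithmetic in plain Python. The\n",
        "helper names `softmax` and `cross_entropy` are ours, introduced only for this\n",
        "illustration, and are not part of TFF or of the model code above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "NUM_CLASSES = 10  # MNIST digits 0..9.\n",
        "\n",
        "\n",
        "def softmax(logits):\n",
        "  # Subtract the max logit for numerical stability.\n",
        "  top = max(logits)\n",
        "  exps = [math.exp(z - top) for z in logits]\n",
        "  total = sum(exps)\n",
        "  return [e / total for e in exps]\n",
        "\n",
        "\n",
        "def cross_entropy(probs, label):\n",
        "  # Negative log-probability assigned to the true label.\n",
        "  return -math.log(probs[label])\n",
        "\n",
        "\n",
        "# An all-zero model produces all-zero logits for any input image.\n",
        "zero_logits = [0.0] * NUM_CLASSES\n",
        "uniform_loss = cross_entropy(softmax(zero_logits), label=5)\n",
        "print(uniform_loss, math.log(NUM_CLASSES))"
      ]
    },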
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ckigEAyDAWFV"
      },
      "source": [
        "Note that we feed the TFF computation with the initial model defined as a\n",
        "`dict`, even though the body of the Python function that defines it consumes\n",
        "model parameters as `model['weights']` and `model['bias']`. The arguments of the call\n",
        "to `batch_loss` aren't simply passed to the body of that function.\n",
        "\n",
        "\n",
        "What happens when we invoke `batch_loss`?\n",
        "The Python body of `batch_loss` has already been traced and serialized in the above cell where it was defined. TFF acts as the caller to `batch_loss`\n",
        "at the computation definition time, and as the target of invocation at the time\n",
        "`batch_loss` is invoked. In both roles, TFF serves as the bridge between TFF's\n",
        "abstract type system and Python representation types. At the invocation time,\n",
        "TFF will accept most standard Python container types (`dict`, `list`, `tuple`,\n",
        "`collections.namedtuple`, etc.) as concrete representations of abstract TFF\n",
        "tuples. Also, although as noted above, TFF computations formally only accept a\n",
        "single parameter, you can use the familiar Python call syntax with positional\n",
        "and/or keyword arguments in cases where the type of the parameter is a tuple - it\n",
        "works as expected."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eB510nILYbId"
      },
      "source": [
        "### Gradient descent on a single batch\n",
        "\n",
        "Now, let's define a computation that uses this loss function to perform a single\n",
        "step of gradient descent. Note how in defining this function, we reuse\n",
        "`forward_pass`, the TensorFlow logic behind `batch_loss`. You can also invoke a\n",
        "computation constructed with `tff.tf_computation` inside the body of another\n",
        "computation, though typically\n",
        "this is not necessary - as noted above, because serialization loses some\n",
        "debugging information, it is often preferable for more complex computations to\n",
        "write and test all the TensorFlow without the `tff.tf_computation` decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O4uaVxw3AyYS"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)\n",
        "def batch_train(initial_model, batch, learning_rate):\n",
        "  # Define a group of model variables and set them to `initial_model`. Must\n",
        "  # be defined outside the @tf.function.\n",
        "  model_vars = collections.OrderedDict([\n",
        "      (name, tf.Variable(name=name, initial_value=value))\n",
        "      for name, value in initial_model.items()\n",
        "  ])\n",
        "  optimizer = tf.keras.optimizers.SGD(learning_rate)\n",
        "\n",
        "  @tf.function\n",
        "  def _train_on_batch(model_vars, batch):\n",
        "    # Perform one step of gradient descent using the loss from `forward_pass`.\n",
        "    with tf.GradientTape() as tape:\n",
        "      loss = forward_pass(model_vars, batch)\n",
        "    grads = tape.gradient(loss, model_vars)\n",
        "    optimizer.apply_gradients(\n",
        "        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))\n",
        "    return model_vars\n",
        "\n",
        "  return _train_on_batch(model_vars, batch)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y84gQsaohC38"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'"
            ]
          },
          "execution_count": 16,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(batch_train.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ID8xg9FCUL2A"
      },
      "source": [
        "When you invoke a Python function decorated with `tff.tf_computation` within the\n",
        "body of another such function, the logic of the inner TFF computation is\n",
        "embedded (essentially, inlined) in the logic of the outer one. As noted above,\n",
        "if you are writing both computations, it is likely preferable to make the inner\n",
        "function (`batch_loss` in this case) a regular Python or `tf.function` rather\n",
        "than a `tff.tf_computation`. However, here we illustrate that calling one\n",
        "`tff.tf_computation` inside another basically works as expected. This may be\n",
        "necessary if, for example, you do not have the Python code defining\n",
        "`batch_loss`, but only its serialized TFF representation.\n",
        "\n",
        "Now, let's apply this function a few times to the initial model to see whether\n",
        "the loss decreases."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8edcJTlXUULm"
      },
      "outputs": [],
      "source": [
        "model = initial_model\n",
        "losses = []\n",
        "for _ in range(5):\n",
        "  model = batch_train(model, sample_batch, 0.1)\n",
        "  losses.append(batch_loss(model, sample_batch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3n1onojT1zHv"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]"
            ]
          },
          "execution_count": 18,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQk4Ha8PU-3P"
      },
      "source": [
        "### Gradient descent on a sequence of local data\n",
        "\n",
        "Now, since `batch_train` appears to work, let's write a similar training\n",
        "function `local_train` that consumes the entire sequence of all batches from one\n",
        "user instead of just a single batch. The new computation will need to now\n",
        "consume `tff.SequenceType(BATCH_TYPE)` instead of `BATCH_TYPE`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EfPD5a6QVNXM"
      },
      "outputs": [],
      "source": [
        "LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)\n",
        "\n",
        "@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)\n",
        "def local_train(initial_model, learning_rate, all_batches):\n",
        "\n",
        "  # Mapping function to apply to each batch.\n",
        "  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)\n",
        "  def batch_fn(model, batch):\n",
        "    return batch_train(model, batch, learning_rate)\n",
        "\n",
        "  return tff.sequence_reduce(all_batches, initial_model, batch_fn)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sAhkS5yKUgjC"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'"
            ]
          },
          "execution_count": 20,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(local_train.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYT-SiopYBtH"
      },
      "source": [
        "There are quite a few details buried in this short section of code, let's go\n",
        "over them one by one.\n",
        "\n",
        "First, while we could have implemented this logic entirely in TensorFlow,\n",
        "relying on `tf.data.Dataset.reduce` to process the sequence similarly to how\n",
        "we've done it earlier, we've opted this time to express the logic in the glue\n",
        "language, as a `tff.federated_computation`. We've used the federated operator\n",
        "`tff.sequence_reduce` to perform the reduction.\n",
        "\n",
        "The operator `tff.sequence_reduce` is used similarly to\n",
        "`tf.data.Dataset.reduce`. You can think of it as essentially the same as\n",
        "`tf.data.Dataset.reduce`, but for use inside federated computations, which as\n",
        "you may remember, cannot contain TensorFlow code. It is a template operator with\n",
        "a formal parameter 3-tuple that consists of a *sequence* of `T`-typed elements,\n",
        "the initial state of the reduction (we'll refer to it abstractly as *zero*) of\n",
        "some type `U`, and the *reduction operator* of type `(<U,T> -> U)` that alters the\n",
        "state of the reduction by processing a single element. The result is the final\n",
        "state of the reduction, after processing all elements in a sequential order. In\n",
        "our example, the state of the reduction is the model trained on a prefix of the\n",
        "data, and the elements are data batches.\n",
        "\n",
        "Second, note that we have again used one computation (`batch_train`) as a\n",
        "component within another (`local_train`), but not directly. We can't use it as a\n",
        "reduction operator because it takes an additional parameter - the learning rate.\n",
        "To resolve this, we define an embedded federated computation `batch_fn` that\n",
        "binds to the `local_train`'s parameter `learning_rate` in its body. It is\n",
        "allowed for a child computation defined this way to capture a formal parameter\n",
        "of its parent as long as the child computation is not invoked outside the body\n",
        "of its parent. You can think of this pattern as an equivalent of\n",
        "`functools.partial` in Python.\n",
        "\n",
        "The practical implication of capturing `learning_rate` this way is, of course,\n",
        "that the same learning rate value is used across all batches.\n",
        "\n",
        "Now, let's try the newly defined local training function on the entire sequence\n",
        "of data from the same user who contributed the sample batch (digit `5`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EnWFLoZGcSby"
      },
      "outputs": [],
      "source": [
        "locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y0UXUqGk9zoF"
      },
      "source": [
        "Did it work? To answer this question, we need to implement evaluation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a8WDKu6WYy__"
      },
      "source": [
        "### Local evaluation\n",
        "\n",
        "Here's one way to implement local evaluation by adding up the losses across all data\n",
        "batches (we could have just as well computed the average; we'll leave it as an\n",
        "exercise for the reader)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0RiODuc6z7Ln"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)\n",
        "def local_eval(model, all_batches):\n",
        "  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.\n",
        "  return tff.sequence_sum(\n",
        "      tff.sequence_map(\n",
        "          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),\n",
        "          all_batches))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pH2XPEAKa4Dg"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'"
            ]
          },
          "execution_count": 23,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "str(local_eval.type_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "efX81HuE-BcO"
      },
      "source": [
        "Again, there are a few new elements illustrated by this code, let's go over them\n",
        "one by one.\n",
        "\n",
        "First, we have used two new federated operators for processing sequences:\n",
        "`tff.sequence_map` that takes a *mapping function* `T->U` and a *sequence* of\n",
        "`T`, and emits a sequence of `U` obtained by applying the mapping function\n",
        "pointwise, and `tff.sequence_sum` that just adds all the elements. Here, we map\n",
        "each data batch to a loss value, and then add the resulting loss values to\n",
        "compute the total loss.\n",
        "\n",
        "Note that we could have again used `tff.sequence_reduce`, but this wouldn't be\n",
        "the best choice - the reduction process is, by definition, sequential, whereas\n",
        "the mapping and sum can be computed in parallel. When given a choice, it's best\n",
        "to stick with operators that don't constrain implementation choices, so that\n",
        "when our TFF computation is compiled in the future to be deployed to a specific\n",
        "environment, one can take full advantage of all potential opportunities for a\n",
        "faster, more scalable, more resource-efficient execution.\n",
        "\n",
        "Second, note that just as in `local_train`, the component function we need\n",
        "(`batch_loss`) takes more parameters than what the federated operator\n",
        "(`tff.sequence_map`) expects, so we again define a partial, this time inline by\n",
        "directly wrapping a `lambda` as a `tff.federated_computation`. Using wrappers\n",
        "inline with a function as an argument is the recommended way to use\n",
        "`tff.tf_computation` to embed TensorFlow logic in TFF.\n",
        "\n",
        "Now, let's see whether our training worked."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vPw6JSVf5q_x"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025854\n",
            "locally_trained_model loss = 0.4348469\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', local_eval(initial_model,\n",
        "                                         federated_train_data[5]))\n",
        "print('locally_trained_model loss =',\n",
        "      local_eval(locally_trained_model, federated_train_data[5]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Tvu70cnBsUf"
      },
      "source": [
        "Indeed, the loss decreased. But what happens if we evaluated it on another\n",
        "user's data?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gjF0NYAj5wls"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025854\n",
            "locally_trained_model loss = 74.50075\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', local_eval(initial_model,\n",
        "                                         federated_train_data[0]))\n",
        "print('locally_trained_model loss =',\n",
        "      local_eval(locally_trained_model, federated_train_data[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7WPumnRTBzUs"
      },
      "source": [
        "As expected, things got worse. The model was trained to recognize `5`, and has\n",
        "never seen a `0`. This brings the question - how did the local training impact\n",
        "the quality of the model from the global perspective?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QJnL2mQRZKTO"
      },
      "source": [
        "### Federated evaluation\n",
        "\n",
        "This is the point in our journey where we finally circle back to federated types\n",
        "and federated computations - the topic that we started with. Here's a pair of\n",
        "TFF types definitions for the model that originates at the server, and the data\n",
        "that remains on the clients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LjGGhpoEBh_6"
      },
      "outputs": [],
      "source": [
        "SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)\n",
        "CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gTXV2-jZtE3"
      },
      "source": [
        "With all the definitions introduced so far, expressing federated evaluation in\n",
        "TFF is a one-liner - we distribute the model to clients, let each client invoke\n",
        "local evaluation on its local portion of data, and then average out the loss.\n",
        "Here's one way to write this."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2zChEPzEBx4T"
      },
      "outputs": [],
      "source": [
        "@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)\n",
        "def federated_eval(model, data):\n",
        "  return tff.federated_mean(\n",
        "      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IWcNONNWaE0N"
      },
      "source": [
        "We've already seen examples of `tff.federated_mean` and `tff.federated_map`\n",
        "in simpler scenarios, and at the intuitive level, they work as expected, but\n",
        "there's more in this section of code than meets the eye, so let's go over it\n",
        "carefully.\n",
        "\n",
        "First, let's break down the *let each client invoke local evaluation on its\n",
        "local portion of data* part. As you may recall from the preceding sections,\n",
        "`local_eval` has a type signature of the form `(<MODEL_TYPE, LOCAL_DATA_TYPE> ->\n",
        "float32)`.\n",
        "\n",
        "The federated operator `tff.federated_map` is a template that accepts as a\n",
        "parameter a 2-tuple that consists of the *mapping function* of some type `T->U`\n",
        "and a federated value of type `{T}@CLIENTS` (i.e., with member constituents of\n",
        "the same type as the parameter of the mapping function), and returns a result of\n",
        "type `{U}@CLIENTS`.\n",
        "\n",
        "Since we're feeding `local_eval` as a mapping function to apply on a per-client\n",
        "basis, the second argument should be of a federated type `{<MODEL_TYPE,\n",
        "LOCAL_DATA_TYPE>}@CLIENTS`, i.e., in the nomenclature of the preceding sections,\n",
        "it should be a federated tuple. Each client should hold a full set of arguments\n",
        "for `local_eval` as a member consituent. Instead, we're feeding it a 2-element\n",
        "Python `list`. What's happening here?\n",
        "\n",
        "Indeed, this is an example of an *implicit type cast* in TFF, similar to\n",
        "implicit type casts you may have encountered elsewhere, e.g., when you feed an\n",
        "`int` to a function that accepts a `float`. Implicit casting is used scarcily at\n",
        "this point, but we plan to make it more pervasive in TFF as a way to minimize\n",
        "boilerplate.\n",
        "\n",
        "The implicit cast that's applied in this case is the equivalence between\n",
        "federated tuples of the form `{<X,Y>}@Z`, and tuples of federated values\n",
        "`<{X}@Z,{Y}@Z>`. While formally, these two are different type signatures,\n",
        "looking at it from the programmers's perspective, each device in `Z` holds two\n",
        "units of data `X` and `Y`. What happens here is not unlike `zip` in Python, and\n",
        "indeed, we offer an operator `tff.federated_zip` that allows you to perform such\n",
        "conversions explicity. When the `tff.federated_map` encounters a tuple as a\n",
        "second argument, it simply invokes `tff.federated_zip` for you.\n",
        "\n",
        "Given the above, you should now be able to recognize the expression\n",
        "`tff.federated_broadcast(model)` as representing a value of TFF type\n",
        "`{MODEL_TYPE}@CLIENTS`, and `data` as a value of TFF type\n",
        "`{LOCAL_DATA_TYPE}@CLIENTS` (or simply `CLIENT_DATA_TYPE`), the two getting\n",
        "filtered together through an implicit `tff.federated_zip` to form the second\n",
        "argument to `tff.federated_map`.\n",
        "\n",
        "The operator `tff.federated_broadcast`, as you'd expect, simply transfers data\n",
        "from the server to the clients.\n",
        "\n",
        "Now, let's see how our local training affected the average loss in the system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tbmtJItcn94j"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model loss = 23.025852\n",
            "locally_trained_model loss = 54.432625\n"
          ]
        }
      ],
      "source": [
        "print('initial_model loss =', federated_eval(initial_model,\n",
        "                                             federated_train_data))\n",
        "print('locally_trained_model loss =',\n",
        "      federated_eval(locally_trained_model, federated_train_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LQi2rGX_fK7i"
      },
      "source": [
        "Indeed, as expected, the loss has increased. In order to improve the model for\n",
        "all users, we'll need to train in on everyone's data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkw9f59qfS7o"
      },
      "source": [
        "### Federated training\n",
        "\n",
        "The simplest way to implement federated training is to locally train, and then\n",
        "average the models. This uses the same building blocks and patters we've already\n",
        "discussed, as you can see below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mBOC4uoG6dd-"
      },
      "outputs": [],
      "source": [
        "SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)\n",
        "\n",
        "\n",
        "@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,\n",
        "                           CLIENT_DATA_TYPE)\n",
        "def federated_train(model, learning_rate, data):\n",
        "  return tff.federated_mean(\n",
        "      tff.federated_map(local_train, [\n",
        "          tff.federated_broadcast(model),\n",
        "          tff.federated_broadcast(learning_rate), data\n",
        "      ]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z2vACMsQjzO1"
      },
      "source": [
        "Note that in the full-featured implementation of Federated Averaging provided by\n",
        "`tff.learning`, rather than averaging the models, we prefer to average model\n",
        "deltas, for a number of reasons, e.g., the ability to clip the update norms,\n",
        "for compression, etc.\n",
        "\n",
        "Let's see whether the training works by running a few rounds of training and\n",
        "comparing the average loss before and after."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NLx-3rLs9jGY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "round 0, loss=21.60552406311035\n",
            "round 1, loss=20.365678787231445\n",
            "round 2, loss=19.27480125427246\n",
            "round 3, loss=18.31110954284668\n",
            "round 4, loss=17.45725440979004\n"
          ]
        }
      ],
      "source": [
        "model = initial_model\n",
        "learning_rate = 0.1\n",
        "for round_num in range(5):\n",
        "  model = federated_train(model, learning_rate, federated_train_data)\n",
        "  learning_rate = learning_rate * 0.9\n",
        "  loss = federated_eval(model, federated_train_data)\n",
        "  print('round {}, loss={}'.format(round_num, loss))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z0VjSLQzlUIp"
      },
      "source": [
        "For completeness, let's now also run on the test data to confirm that our model\n",
        "generalizes well."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZaZT45yFMOaM"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial_model test loss = 22.795593\n",
            "trained_model test loss = 17.278767\n"
          ]
        }
      ],
      "source": [
        "print('initial_model test loss =',\n",
        "      federated_eval(initial_model, federated_test_data))\n",
        "print('trained_model test loss =', federated_eval(model, federated_test_data))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pxlHHwLGlgFB"
      },
      "source": [
        "This concludes our tutorial.\n",
        "\n",
        "Of course, our simplified example doesn't reflect a number of things you'd need\n",
        "to do in a more realistic scenario - for example, we haven't computed metrics\n",
        "other than loss. We encourage you to study\n",
        "[the implementation](https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/federated_averaging.py)\n",
        "of federated averaging in `tff.learning` as a more complete example, and as a\n",
        "way to demonstrate some of the coding practices we'd like to encourage."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_federated_algorithms_2.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
